Add packet-switched flash attention programming example #1532

Draft
erwei-xilinx wants to merge 1 commit into Xilinx:main from erwei-xilinx:packet_switched_flash_attention

Conversation

@erwei-xilinx
Collaborator

Summary

  • Recover the original packet-switched flash attention design, as it existed before PR #1466 ("Replace flash attention with memtile-relayed selective-capture implementation"), as a standalone example under programming_examples/flash_attention/packet_switched/
  • Uses channel_type="dma_packet" for Q and K routing through shared compute tile S2MM DMA channels, demonstrating hardware packet routing in the stream switch
  • Supports both NPU2 (AIE2P) and NPU1 (AIE2) targets

Channel routing

| Channel | Type | Purpose |
| --- | --- | --- |
| L2ToL1Chan1 (Q) | dma_packet | Broadcast to all compute tiles |
| L2ToL1Chan2 (K) | dma_packet | Broadcast to all compute tiles |
| L2ToL1Chan3 (V) | dma_stream | Circuit-switched per cascade stage |
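
To illustrate what packet switching buys over a dedicated circuit, here is a minimal plain-Python sketch of two logical flows (Q and K) sharing one link and being separated again by a packet ID. It is a toy model only: the `Packet` class, the `PKT_ID` values, and the round-robin interleaving are assumptions made for illustration, and nothing here reflects the real stream-switch header format, the hardware arbitration policy, or the API used in attn.py.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Packet:
    pkt_id: int   # hypothetical header ID identifying the logical flow
    payload: str  # stand-in for a tile of Q or K data


# Hypothetical ID assignment; the real IDs are chosen by the compiler.
PKT_ID = {"Q": 0, "K": 1}


def multiplex(flows):
    """Interleave chunks from several logical flows onto one shared link.

    Round-robin order is an assumption of this sketch, not a statement
    about how the hardware arbitrates.
    """
    link = []
    iters = {name: iter(chunks) for name, chunks in flows.items()}
    while iters:
        for name in list(iters):  # copy the keys so we can delete while looping
            try:
                chunk = next(iters[name])
            except StopIteration:
                del iters[name]  # this flow has no more data
                continue
            link.append(Packet(PKT_ID[name], chunk))
    return link


def demultiplex(link):
    """Split the shared link back into per-flow streams using the packet ID."""
    by_id = defaultdict(list)
    for pkt in link:
        by_id[pkt.pkt_id].append(pkt.payload)
    return dict(by_id)


flows = {"Q": ["q0", "q1"], "K": ["k0", "k1", "k2"]}
link = multiplex(flows)
received = demultiplex(link)
print([p.payload for p in link])
print(received)
```

In this model the V flow would simply have its own link with no IDs, which is the circuit-switched (dma_stream) case in the table.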

Files

| File | Description |
| --- | --- |
| attn.py + attn_pkt.cc | NPU2 variant (mmul<8,8,8>, n-major B-blocks) |
| attn_npu1.py + attn_npu1.cc | NPU1 variant (mmul<4,8,4>, k-major B-blocks, LUT exp) |
| Makefile | Build targets for both NPU1 (run-npu1) and NPU2 (run) |
| run_npu1_makefile_peano.lit | NPU1 LIT test |
| run_npu2_makefile_peano.lit | NPU2 LIT test |

Test plan

  • NPU1 LIT test passes (make run-npu1 with defaults LK=512 LKP=64 LQ=512 LQP=256 DK=64 NUM_HEADS=2)
  • NPU2 LIT test passes (make run with same defaults)
  • Verify dma_packet channel declarations in generated MLIR via make print / make print-npu1
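
The last check could be scripted roughly as follows. This is a sketch under stated assumptions: it assumes the output of make print has been captured as text, and that each channel's name and its channel type string appear on the same line. The helper name `channels_on_packet_lines` is hypothetical, and the `sample` text is a made-up placeholder, not real MLIR syntax or real output from this example.

```python
def channels_on_packet_lines(mlir_text, channel_names):
    """Report which channel names occur on a line that also mentions dma_packet.

    This is a plain substring scan, not an MLIR parser, so it only works if
    the name and the type string share a line (an assumption of this sketch).
    """
    found = {name: False for name in channel_names}
    for line in mlir_text.splitlines():
        if "dma_packet" not in line:
            continue
        for name in channel_names:
            if name in line:
                found[name] = True
    return found


# Placeholder text for demonstration only; not actual generated MLIR.
sample = "\n".join(
    [
        "channel L2ToL1Chan1 type=dma_packet",
        "channel L2ToL1Chan2 type=dma_packet",
        "channel L2ToL1Chan3 type=dma_stream",
    ]
)

result = channels_on_packet_lines(
    sample, ["L2ToL1Chan1", "L2ToL1Chan2", "L2ToL1Chan3"]
)
print(result)
```

For the design described here, the expectation would be that the Q and K channels are reported and the V channel is not.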

🤖 Generated with Claude Code

Recover the original packet-switched flash attention design as a standalone
example. This design uses dma_packet channels to time-multiplex Q and K
data through shared compute tile S2MM DMA channels via hardware packet
routing in the stream switch.

Channel routing:
  L2ToL1Chan1 (Q): dma_packet — broadcast to all compute tiles
  L2ToL1Chan2 (K): dma_packet — broadcast to all compute tiles
  L2ToL1Chan3 (V): dma_stream — circuit-switched per cascade stage

Includes both NPU2 (AIE2P, attn.py + attn_pkt.cc) and NPU1 (AIE2,
attn_npu1.py + attn_npu1.cc) variants. The NPU1 variant reuses the
kernel from kernel_fusion_based with k-major B-block indexing and
adapted DMA layouts for mmul<4,8,4>.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>